{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YS3NA-i6nAFC"
      },
      "outputs": [],
      "source": [
        "##### Copyright 2022 The TensorFlow Authors.\n",
        "\n",
        "\n",
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7SN5USFEIIK3"
      },
      "source": [
        "# Warm-start embedding layer matrix"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Aojnnc7sXrab"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/tutorials/text/warmstart_embedding_matrix\">\n",
        "    <img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />\n",
        "    View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/master/site/en/tutorials/text/warmstart_embedding_matrix.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />\n",
        "    Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/tensorflow/docs/blob/master/site/en/tutorials/text/warmstart_embedding_matrix.ipynb\">\n",
        "    <img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />\n",
        "    View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/tutorials/text/warmstart_embedding_matrix.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q6mJg1g3apaz"
      },
      "source": [
        "This tutorial shows how to \"warm-start\" training using the [`tf.keras.utils.warmstart_embedding_matrix`](https://www.tensorflow.org/api_docs/python/tf/keras/utils/warmstart_embedding_matrix) API for text sentiment classification when changing vocabulary.\n",
        "\n",
        "You will begin by training a simple Keras model with a base vocabulary, and then, after updating the vocabulary, continue training the model. This is referred to as \"warm-start\" training, for which you'll need to remap the text-embedding matrix for the new vocabulary."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "TZhifmcDwJTf"
      },
      "source": [
        "## Embedding matrix\n",
        "\n",
        "Embeddings provide a way to use an efficient, dense representation in which similar vocabulary tokens have a similar encoding. They are trainable parameters (weights learned by the model during training, in the same way a model learns weights for a dense layer). It is common to have embeddings that are 8-dimensional for small datasets, and up to 1024-dimensional when working with large datasets. A higher-dimensional embedding can capture fine-grained relationships between words, but takes more data to learn."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2voNac7BwJ-g"
      },
      "source": [
        "### Vocabulary\n",
        "\n",
        "The set of unique words is referred to as the vocabulary. To build a text model you need to choose a fixed vocabulary. Typically, you build the vocabulary from the most common words in a dataset. The vocabulary allows you to represent each piece of text as a sequence of IDs that you can look up in the embedding matrix. In this way, each piece of text is represented by the specific words that appear in it."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JuBkjGwtwKiv"
      },
      "source": [
        "### Why warm-start an embedding matrix?\n",
        "\n",
        "A model is trained with a set of embeddings that represents a given vocabulary. If the model needs to be updated or improved, you can train to convergence significantly faster by reusing weights from a previous run. Reusing the embedding matrix from a previous run is more difficult, because any change to the vocabulary invalidates the word-to-ID mapping.\n",
        "\n",
        "The `tf.keras.utils.warmstart_embedding_matrix` utility solves this problem by creating an embedding matrix for a new vocabulary from the embedding matrix of a base vocabulary. Where a word exists in both vocabularies, the base embedding vector is copied into the correct location in the new embedding matrix. Embeddings for words that appear only in the new vocabulary are freshly initialized. This allows you to warm-start training after any change in the size or order of the vocabulary."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SZUQErGewZxE"
      },
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "BfPukisbG_Yu"
      },
      "outputs": [],
      "source": [
        "!pip install --pre -U \"tensorflow>2.10\"  # Requires 2.11"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RutaI-Tpev3T"
      },
      "outputs": [],
      "source": [
        "import io\n",
        "import numpy as np\n",
        "import os\n",
        "import re\n",
        "import shutil\n",
        "import string\n",
        "import tensorflow as tf\n",
        "\n",
        "from tensorflow.keras import Model\n",
        "from tensorflow.keras.layers import Dense, Embedding, GlobalAveragePooling1D\n",
        "from tensorflow.keras.layers import TextVectorization"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SBFctV8-JZOc"
      },
      "source": [
        "### Load the dataset\n",
        "The tutorial uses the [Large Movie Review Dataset](http://ai.stanford.edu/~amaas/data/sentiment/). You will train a sentiment classifier model on this dataset and in the process learn embeddings from scratch. Refer to the [Load text tutorial](https://www.tensorflow.org/tutorials/load_data/text) to learn more.\n",
        "\n",
        "Download the dataset using the Keras file utility and review the directories."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aPO4_UmfF0KH"
      },
      "outputs": [],
      "source": [
        "url = \"https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\"\n",
        "\n",
        "dataset = tf.keras.utils.get_file(\n",
        "    \"aclImdb_v1.tar.gz\", url, untar=True, cache_dir=\".\", cache_subdir=\"\"\n",
        ")\n",
        "\n",
        "dataset_dir = os.path.join(os.path.dirname(dataset), \"aclImdb\")\n",
        "os.listdir(dataset_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eY6yROZNKvbd"
      },
      "source": [
        "The `train/` directory has `pos` and `neg` folders with movie reviews labelled as positive and negative respectively. You will use reviews from `pos` and `neg` folders to train a binary classification model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9-iOHJGN6SDu"
      },
      "outputs": [],
      "source": [
        "train_dir = os.path.join(dataset_dir, \"train\")\n",
        "os.listdir(train_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9O59BdioK8jY"
      },
      "source": [
        "The `train` directory also contains an additional `unsup` folder, which should be removed before creating the training set."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1_Vfi9oWMSh-"
      },
      "outputs": [],
      "source": [
        "remove_dir = os.path.join(train_dir, \"unsup\")\n",
        "shutil.rmtree(remove_dir)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "oFoJjiEyJz9u"
      },
      "source": [
        "Next, create a `tf.data.Dataset` using `tf.keras.utils.text_dataset_from_directory`. You can read more about using this utility in this [text classification tutorial](https://www.tensorflow.org/tutorials/keras/text_classification). \n",
        "\n",
        "Use the `train` directory to create the training and validation sets with a split of 20% for validation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ItYD3TLkCOP1"
      },
      "outputs": [],
      "source": [
        "batch_size = 1024\n",
        "seed = 123\n",
        "train_ds = tf.keras.utils.text_dataset_from_directory(\n",
        "    train_dir,\n",
        "    batch_size=batch_size,\n",
        "    validation_split=0.2,\n",
        "    subset=\"training\",\n",
        "    seed=seed,\n",
        ")\n",
        "val_ds = tf.keras.utils.text_dataset_from_directory(\n",
        "    train_dir,\n",
        "    batch_size=batch_size,\n",
        "    validation_split=0.2,\n",
        "    subset=\"validation\",\n",
        "    seed=seed,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FHV2pchDhzDn"
      },
      "source": [
        "### Configure the dataset for performance\n",
        "\n",
        "You can learn more about `Dataset.cache` and `Dataset.prefetch`, as well as how to cache data to disk in the [data performance guide](https://www.tensorflow.org/guide/data_performance)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Oz6k1IW7h1TO"
      },
      "outputs": [],
      "source": [
        "AUTOTUNE = tf.data.AUTOTUNE\n",
        "\n",
        "train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)\n",
        "val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aGicgV5qT0wh"
      },
      "source": [
        "## Text preprocessing"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N6NZSqIIoU0Y"
      },
      "source": [
        "Next, define the dataset preprocessing steps required for your sentiment classification model. Initialize a `tf.keras.layers.TextVectorization` layer with the desired parameters to vectorize movie reviews. You can learn more about using this layer in the [Text Classification](https://www.tensorflow.org/tutorials/keras/text_classification) tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2MlsXzo-ZlfK"
      },
      "outputs": [],
      "source": [
        "# Create a custom standardization function to strip HTML break tags '<br />'.\n",
        "def custom_standardization(input_data):\n",
        "    lowercase = tf.strings.lower(input_data)\n",
        "    stripped_html = tf.strings.regex_replace(lowercase, \"<br />\", \" \")\n",
        "    return tf.strings.regex_replace(\n",
        "        stripped_html, \"[%s]\" % re.escape(string.punctuation), \"\"\n",
        "    )\n",
        "\n",
        "\n",
        "# Vocabulary size and number of words in a sequence.\n",
        "vocab_size = 10000\n",
        "sequence_length = 100\n",
        "\n",
        "# Use the text vectorization layer to normalize, split, and map strings to\n",
        "# integers. Note that the layer uses the custom standardization defined above.\n",
        "# Set `output_sequence_length` because the samples differ in length.\n",
        "vectorize_layer = TextVectorization(\n",
        "    standardize=custom_standardization,\n",
        "    max_tokens=vocab_size,\n",
        "    output_mode=\"int\",\n",
        "    output_sequence_length=sequence_length,\n",
        ")\n",
        "\n",
        "# Make a text-only dataset (no labels) and call `TextVectorization.adapt` to\n",
        "# build the vocabulary.\n",
        "text_ds = train_ds.map(lambda x, y: x)\n",
        "vectorize_layer.adapt(text_ds)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zI9_wLIiWO8Z"
      },
      "source": [
        "## Create a classification model\n",
        "\n",
        "Use the [Keras Sequential API](https://www.tensorflow.org/guide/keras/sequential_model) to define the sentiment classification model. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "pHLcFtn5Wsqj"
      },
      "outputs": [],
      "source": [
        "embedding_dim = 16\n",
        "text_embedding = Embedding(vocab_size, embedding_dim, name=\"embedding\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iXAfZyEIRVY5"
      },
      "outputs": [],
      "source": [
        "text_input = tf.keras.Sequential(\n",
        "    [vectorize_layer, text_embedding], name=\"text_input\"\n",
        ")\n",
        "classifier_head = tf.keras.Sequential(\n",
        "    [GlobalAveragePooling1D(), Dense(16, activation=\"relu\"), Dense(1)],\n",
        "    name=\"classifier_head\",\n",
        ")\n",
        "\n",
        "model = tf.keras.Sequential([text_input, classifier_head])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JjLNgKO7W2fe"
      },
      "source": [
        "## Compile and train the model"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "jpX9etB6IOQd"
      },
      "source": [
        "You will use [TensorBoard](https://www.tensorflow.org/tensorboard) to visualize metrics including loss and accuracy. Create a `tf.keras.callbacks.TensorBoard` callback."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "W4Hg3IHFt4Px"
      },
      "outputs": [],
      "source": [
        "tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=\"logs\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7OrKAKAKIbuH"
      },
      "source": [
        "Compile and train the model using the `Adam` optimizer and `BinaryCrossentropy` loss. "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lCUgdP69Wzix"
      },
      "outputs": [],
      "source": [
        "model.compile(\n",
        "    optimizer=\"adam\",\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[\"accuracy\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "5mQehiQyv8rP"
      },
      "outputs": [],
      "source": [
        "model.fit(\n",
        "    train_ds,\n",
        "    validation_data=val_ds,\n",
        "    epochs=15,\n",
        "    callbacks=[tensorboard_callback],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1wYnVedSPfmX"
      },
      "source": [
        "With this approach the model reaches a validation accuracy of around 85%.\n",
        "\n",
        "Note: Your results may be a bit different, depending on how weights were randomly initialized before training the embedding layer.\n",
        "\n",
        "You can look into the model summary to learn more about each layer of the model."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mDCgjWyq_0dc"
      },
      "outputs": [],
      "source": [
        "model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hiQbOJZ2WBFY"
      },
      "source": [
        "Visualize the model metrics in TensorBoard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_Uanp2YH8RzU"
      },
      "outputs": [],
      "source": [
        "# docs_infra: no_execute\n",
        "%load_ext tensorboard\n",
        "%tensorboard --logdir logs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "NMtMv8yPEf5e"
      },
      "source": [
        "<!-- <img class=\"tfo-display-only-on-site\" src=\"https://tensorflow.org/tutorials/text/images/tensorboard-1.png\"/> -->"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IKp2PvLYI-r2"
      },
      "source": [
        "## Vocabulary remapping\n",
        "\n",
        "Now you're going to update the vocabulary and continue with warm-started training.\n",
        "\n",
        "First, get the base vocabulary and embedding matrix."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "HFgt2n6HJDAw"
      },
      "outputs": [],
      "source": [
        "embedding_weights_base = (\n",
        "    model.get_layer(\"text_input\").get_layer(\"embedding\").get_weights()[0]\n",
        ")\n",
        "vocab_base = vectorize_layer.get_vocabulary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A8wuaIVkJaNw"
      },
      "source": [
        "Define a new vectorization layer to generate a new, larger vocabulary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_-YcdW4XJlcX"
      },
      "outputs": [],
      "source": [
        "# Vocabulary size and number of words in a sequence.\n",
        "vocab_size_new = 10200\n",
        "sequence_length = 100\n",
        "\n",
        "vectorize_layer_new = TextVectorization(\n",
        "    standardize=custom_standardization,\n",
        "    max_tokens=vocab_size_new,\n",
        "    output_mode=\"int\",\n",
        "    output_sequence_length=sequence_length,\n",
        ")\n",
        "\n",
        "# Make a text-only dataset (no labels) and call adapt to build the vocabulary.\n",
        "text_ds = train_ds.map(lambda x, y: x)\n",
        "vectorize_layer_new.adapt(text_ds)\n",
        "\n",
        "# Get the new vocabulary\n",
        "vocab_new = vectorize_layer_new.get_vocabulary()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lny782PFNF3j"
      },
      "outputs": [],
      "source": [
        "# View the new vocabulary tokens that weren't in `vocab_base`\n",
        "set(vocab_new) - set(vocab_base)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nHsDOlAnJrFH"
      },
      "source": [
        "Generate updated embeddings using the `tf.keras.utils.warmstart_embedding_matrix` utility."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "MgBlw3VnKrBL"
      },
      "outputs": [],
      "source": [
        "# Generate the updated embedding matrix\n",
        "updated_embedding = tf.keras.utils.warmstart_embedding_matrix(\n",
        "    base_vocabulary=vocab_base,\n",
        "    new_vocabulary=vocab_new,\n",
        "    base_embeddings=embedding_weights_base,\n",
        "    new_embeddings_initializer=\"uniform\",\n",
        ")\n",
        "# Wrap the updated matrix in a `tf.Variable`\n",
        "updated_embedding_variable = tf.Variable(updated_embedding)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEhm8fldKyR_"
      },
      "source": [
        "**OR**\n",
        "\n",
        "If you have an embedding matrix that you would like to use to initialize the new embedding matrix, pass `tf.keras.initializers.Constant` as the `new_embeddings_initializer`. This is helpful when you have a better initialization for the words that are new in the vocabulary. Copy the following block into a code cell to try it out.\n",
        "\n",
        "```\n",
        "# Generate the updated embedding matrix\n",
        "new_embedding = np.random.rand(len(vocab_new), embedding_dim)\n",
        "updated_embedding = tf.keras.utils.warmstart_embedding_matrix(\n",
        "    base_vocabulary=vocab_base,\n",
        "    new_vocabulary=vocab_new,\n",
        "    base_embeddings=embedding_weights_base,\n",
        "    new_embeddings_initializer=tf.keras.initializers.Constant(new_embedding),\n",
        ")\n",
        "# Wrap the updated matrix in a `tf.Variable`\n",
        "updated_embedding_variable = tf.Variable(updated_embedding)\n",
        "```"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DbKLjXfhLUVa"
      },
      "source": [
        "Verify that the embedding matrix's shape has changed to reflect the new vocabulary."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "tYDrBBEtLWZQ"
      },
      "outputs": [],
      "source": [
        "updated_embedding_variable.shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iCd8LnSILZqk"
      },
      "source": [
        "Now that you have the updated embedding matrix, the next step is to update the layer weights."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i-PdukkBLlx1"
      },
      "outputs": [],
      "source": [
        "text_embedding_layer_new = Embedding(\n",
        "    vectorize_layer_new.vocabulary_size(), embedding_dim, name=\"embedding\"\n",
        ")\n",
        "text_embedding_layer_new.build(input_shape=[None])\n",
        "text_embedding_layer_new.embeddings.assign(updated_embedding)\n",
        "text_input_new = tf.keras.Sequential(\n",
        "    [vectorize_layer_new, text_embedding_layer_new], name=\"text_input_new\"\n",
        ")\n",
        "text_input_new.summary()\n",
        "\n",
        "# Verify the shape of updated weights\n",
        "# The new weights shape should reflect the new vocabulary size\n",
        "text_input_new.get_layer(\"embedding\").get_weights()[0].shape"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "juAdUZkVMpEj"
      },
      "source": [
        "Modify the model architecture to use the new text vectorization layer.\n",
        "\n",
        "You can also load the model from a checkpoint and update the model architecture as shown below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "etCo20sPNn2C"
      },
      "outputs": [],
      "source": [
        "warm_started_model = tf.keras.Sequential([text_input_new, classifier_head])\n",
        "warm_started_model.summary()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Hi4r5FubN202"
      },
      "source": [
        "You have successfully updated the model to accept a new vocabulary. The embedding layer now maps old vocabulary words to their old embeddings and initializes embeddings for the new vocabulary words, which are still to be learned. The learned weights of the rest of the model remain the same. The model is warm-started to continue training from where it left off.\n",
        "\n",
        "You can now verify that the remapping worked. Get the index of the vocabulary word \"the\", which is present in both the base and the new vocabulary, and compare the embedding values. They should be equal."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vCdlWvpPPEow"
      },
      "outputs": [],
      "source": [
        "# Compare the embeddings of a word present in both vocabularies\n",
        "base_vocab_index = vectorize_layer(\"the\")[0]\n",
        "new_vocab_index = vectorize_layer_new(\"the\")[0]\n",
        "print(\n",
        "    warm_started_model.get_layer(\"text_input_new\").get_layer(\"embedding\")(\n",
        "        new_vocab_index\n",
        "    )\n",
        "    == embedding_weights_base[base_vocab_index]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "g1xX0XCEReRC"
      },
      "source": [
        "## Continue with warm-started training\n",
        "\n",
        "Notice how the training is warm-started. The accuracy of the first epoch is around 85%, close to the accuracy at which the previous training ended."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OtbXMQsTRdvq"
      },
      "outputs": [],
      "source": [
        "warm_started_model.compile(\n",
        "    optimizer=\"adam\",\n",
        "    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),\n",
        "    metrics=[\"accuracy\"],\n",
        ")\n",
        "warm_started_model.fit(\n",
        "    train_ds,\n",
        "    validation_data=val_ds,\n",
        "    epochs=15,\n",
        "    callbacks=[tensorboard_callback],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "5z67BhOZR6do"
      },
      "source": [
        "## Visualize warm-started training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eXPXUfw3QTY-"
      },
      "outputs": [],
      "source": [
        "# docs_infra: no_execute\n",
        "%reload_ext tensorboard\n",
        "%tensorboard --logdir logs"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a-9MOqehCQa8"
      },
      "source": [
        "<!-- <img class=\"tfo-display-only-on-site\" src=\"https://tensorflow.org/tutorials/text/images/tensorboard-2.png\"/> -->"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-SrmuEgJQIIP"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "In this tutorial you learned how to:\n",
        "\n",
        "* Train a sentiment classification model from scratch on a small vocabulary dataset.\n",
        "* Update the model architecture and warm-start the embedding matrix when the vocabulary size changes.\n",
        "* Continuously improve model accuracy with expanding datasets.\n",
        "\n",
        "To learn more about embeddings, check out the [Word2Vec](https://www.tensorflow.org/tutorials/text/word2vec) and [Transformer model for language understanding](https://www.tensorflow.org/text/tutorials/transformer) tutorials."
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "warmstart_embedding_matrix.ipynb",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
